{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "vY4SK0xKAJgm"
   },
   "source": [
    "# Model Zoo -- RNN with LSTM with Own Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "sc6xejhY-NzZ"
   },
   "source": [
    "Example notebook showing how to use your own CSV text dataset to train a simple RNN for sentiment classification (here: a binary classification problem with two labels, positive and negative) using LSTM (Long Short-Term Memory) cells."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "moNmVfuvnImW"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.6.8\n",
      "IPython 7.2.0\n",
      "\n",
      "torch 1.0.1.post2\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torchtext import data\n",
    "from torchtext import datasets\n",
    "import time\n",
    "import random\n",
    "import pandas as pd\n",
    "\n",
    "torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "GSRL42Qgy8I8"
   },
   "source": [
    "## General Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "OvW1RgfepCBq"
   },
   "outputs": [],
   "source": [
    "RANDOM_SEED = 123\n",
    "torch.manual_seed(RANDOM_SEED)\n",
    "\n",
    "VOCABULARY_SIZE = 20000\n",
    "LEARNING_RATE = 1e-4\n",
    "BATCH_SIZE = 128\n",
    "NUM_EPOCHS = 15\n",
    "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "EMBEDDING_DIM = 128\n",
    "HIDDEN_DIM = 256\n",
    "OUTPUT_DIM = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "mQMmKUEisW4W"
   },
   "source": [
    "## Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following cells download the IMDB movie review dataset (http://ai.stanford.edu/~amaas/data/sentiment/) for positive-negative sentiment classification as a CSV-formatted file:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2019-11-28 19:47:46--  https://github.com/rasbt/python-machine-learning-book-2nd-edition/raw/master/code/ch08/movie_data.csv.gz\n",
      "Resolving github.com (github.com)... 140.82.113.3\n",
      "Connecting to github.com (github.com)|140.82.113.3|:443... connected.\n",
      "HTTP request sent, awaiting response... 302 Found\n",
      "Location: https://raw.githubusercontent.com/rasbt/python-machine-learning-book-2nd-edition/master/code/ch08/movie_data.csv.gz [following]\n",
      "--2019-11-28 19:47:46--  https://raw.githubusercontent.com/rasbt/python-machine-learning-book-2nd-edition/master/code/ch08/movie_data.csv.gz\n",
      "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.184.133\n",
      "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.184.133|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 26521894 (25M) [application/octet-stream]\n",
      "Saving to: ‘movie_data.csv.gz’\n",
      "\n",
      "movie_data.csv.gz   100%[===================>]  25.29M  10.5MB/s    in 2.4s    \n",
      "\n",
      "2019-11-28 19:47:49 (10.5 MB/s) - ‘movie_data.csv.gz’ saved [26521894/26521894]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!wget https://github.com/rasbt/python-machine-learning-book-2nd-edition/raw/master/code/ch08/movie_data.csv.gz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "!gunzip -f movie_data.csv.gz "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check that the dataset looks okay:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>review</th>\n",
       "      <th>sentiment</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>In 1974, the teenager Martha Moxley (Maggie Gr...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>OK... so... I really like Kris Kristofferson a...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>***SPOILER*** Do not read this, if you think a...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>hi for all the people who have seen this wonde...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>I recently bought the DVD, forgetting just how...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                              review  sentiment\n",
       "0  In 1974, the teenager Martha Moxley (Maggie Gr...          1\n",
       "1  OK... so... I really like Kris Kristofferson a...          0\n",
       "2  ***SPOILER*** Do not read this, if you think a...          0\n",
       "3  hi for all the people who have seen this wonde...          1\n",
       "4  I recently bought the DVD, forgetting just how...          0"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv('movie_data.csv')\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "del df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "4GnH64XvsV8n"
   },
   "source": [
    "Define the Label and Text field formatters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "TEXT = data.Field(sequential=True,\n",
    "                  tokenize='spacy',\n",
    "                  include_lengths=True) # necessary for pack_padded_sequence\n",
    "\n",
    "LABEL = data.LabelField(dtype=torch.float)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Process the dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "fields = [('review', TEXT), ('sentiment', LABEL)]\n",
    "\n",
    "dataset = data.TabularDataset(\n",
    "    path=\"movie_data.csv\", format='csv',\n",
    "    skip_header=True, fields=fields)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Split the dataset into training, validation, and test partitions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 68
    },
    "colab_type": "code",
    "id": "WZ_4jiHVnMxN",
    "outputId": "dfa51c04-4845-44c3-f50b-d36d41f132b8"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num Train: 37500\n",
      "Num Valid: 10000\n",
      "Num Test: 2500\n"
     ]
    }
   ],
   "source": [
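    "# split_ratio is ordered [train, test, valid], whereas the returned\n",
    "# datasets are ordered (train, valid, test): 75% train, 20% valid, 5% test\n",
    "random.seed(RANDOM_SEED)\n",
    "train_data, valid_data, test_data = dataset.split(\n",
    "    split_ratio=[0.75, 0.05, 0.2],\n",
    "    random_state=random.getstate())\n",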
    "\n",
    "print(f'Num Train: {len(train_data)}')\n",
    "print(f'Num Valid: {len(valid_data)}')\n",
    "print(f'Num Test: {len(test_data)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "L-TBwKWPslPa"
   },
   "source": [
    "Build the vocabulary from the `VOCABULARY_SIZE` most frequent words in the training set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "e8uNrjdtn4A8",
    "outputId": "6cf499d7-7722-4da0-8576-ee0f218cc6e3"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Vocabulary size: 20002\n",
      "Number of classes: 2\n"
     ]
    }
   ],
   "source": [
    "TEXT.build_vocab(train_data, max_size=VOCABULARY_SIZE)\n",
    "LABEL.build_vocab(train_data)\n",
    "\n",
    "print(f'Vocabulary size: {len(TEXT.vocab)}')\n",
    "print(f'Number of classes: {len(LABEL.vocab)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({'0': 18742, '1': 18758})"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "LABEL.vocab.freqs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "JpEMNInXtZsb"
   },
   "source": [
    "`TEXT.vocab` stores the word counts and the word-to-index mapping. The vocabulary contains `VOCABULARY_SIZE` + 2 entries because it also includes two special tokens, `<unk>` for unknown words and `<pad>` for padding."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "eIQ_zfKLwjKm"
   },
   "source": [
    "Make dataset iterators:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "i7JiHR1stHNF"
   },
   "outputs": [],
   "source": [
    "train_loader, valid_loader, test_loader = data.BucketIterator.splits(\n",
    "    (train_data, valid_data, test_data), \n",
    "    batch_size=BATCH_SIZE,\n",
    "    sort_within_batch=True, # necessary for pack_padded_sequence\n",
    "    sort_key=lambda x: len(x.review),\n",
    "    device=DEVICE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "R0pT_dMRvicQ"
   },
   "source": [
    "Test the iterators (note that the number of rows in each text matrix equals the length of the longest document in the respective batch):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 204
    },
    "colab_type": "code",
    "id": "y8SP_FccutT0",
    "outputId": "fe33763a-4560-4dee-adee-31cc6c48b0b2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train\n",
      "Text matrix size: torch.Size([512, 128])\n",
      "Target vector size: torch.Size([128])\n",
      "\n",
      "Valid:\n",
      "Text matrix size: torch.Size([52, 128])\n",
      "Target vector size: torch.Size([128])\n",
      "\n",
      "Test:\n",
      "Text matrix size: torch.Size([75, 128])\n",
      "Target vector size: torch.Size([128])\n"
     ]
    }
   ],
   "source": [
    "print('Train')\n",
    "for batch in train_loader:\n",
    "    print(f'Text matrix size: {batch.review[0].size()}')\n",
    "    print(f'Target vector size: {batch.sentiment.size()}')\n",
    "    break\n",
    "    \n",
    "print('\\nValid:')\n",
    "for batch in valid_loader:\n",
    "    print(f'Text matrix size: {batch.review[0].size()}')\n",
    "    print(f'Target vector size: {batch.sentiment.size()}')\n",
    "    break\n",
    "    \n",
    "print('\\nTest:')\n",
    "for batch in test_loader:\n",
    "    print(f'Text matrix size: {batch.review[0].size()}')\n",
    "    print(f'Target vector size: {batch.sentiment.size()}')\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "G_grdW3pxCzz"
   },
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "nQIUm5EjxFNa"
   },
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "class RNN(nn.Module):\n",
    "    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):\n",
    "        \n",
    "        super().__init__()\n",
    "        \n",
    "        self.embedding = nn.Embedding(input_dim, embedding_dim)\n",
    "        self.rnn = nn.LSTM(embedding_dim, hidden_dim)\n",
    "        self.fc = nn.Linear(hidden_dim, output_dim)\n",
    "        \n",
    "    def forward(self, text, text_length):\n",
    "\n",
    "        #[sentence len, batch size] => [sentence len, batch size, embedding size]\n",
    "        embedded = self.embedding(text)\n",
    "        \n",
    "        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, text_length)\n",
    "        \n",
    "        #[sentence len, batch size, embedding size] => \n",
    "        #  output: [sentence len, batch size, hidden size]\n",
    "        #  hidden: [1, batch size, hidden size]\n",
    "        packed_output, (hidden, cell) = self.rnn(packed)\n",
    "        \n",
    "        return self.fc(hidden.squeeze(0)).view(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Ik3NF3faxFmZ"
   },
   "outputs": [],
   "source": [
    "INPUT_DIM = len(TEXT.vocab)\n",
    "\n",
    "torch.manual_seed(RANDOM_SEED)\n",
    "model = RNN(INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM)\n",
    "model = model.to(DEVICE)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Lv9Ny9di6VcI"
   },
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "T5t1Afn4xO11"
   },
   "outputs": [],
   "source": [
    "def compute_binary_accuracy(model, data_loader, device):\n",
    "    model.eval()\n",
    "    correct_pred, num_examples = 0, 0\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, batch_data in enumerate(data_loader):\n",
    "            text, text_lengths = batch_data.review\n",
    "            logits = model(text, text_lengths)\n",
    "            predicted_labels = (torch.sigmoid(logits) > 0.5).long()\n",
    "            num_examples += batch_data.sentiment.size(0)\n",
    "            correct_pred += (predicted_labels == batch_data.sentiment.long()).sum()\n",
    "        return correct_pred.float()/num_examples * 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1836
    },
    "colab_type": "code",
    "id": "EABZM8Vo0ilB",
    "outputId": "5d45e293-9909-4588-e793-8dfaf72e5c67"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/015 | Batch 000/293 | Cost: 0.6948\n",
      "Epoch: 001/015 | Batch 050/293 | Cost: 0.6868\n",
      "Epoch: 001/015 | Batch 100/293 | Cost: 0.6926\n",
      "Epoch: 001/015 | Batch 150/293 | Cost: 0.6788\n",
      "Epoch: 001/015 | Batch 200/293 | Cost: 0.6838\n",
      "Epoch: 001/015 | Batch 250/293 | Cost: 0.6563\n",
      "training accuracy: 64.83%\n",
      "valid accuracy: 64.88%\n",
      "Time elapsed: 0.29 min\n",
      "Epoch: 002/015 | Batch 000/293 | Cost: 0.5635\n",
      "Epoch: 002/015 | Batch 050/293 | Cost: 0.6154\n",
      "Epoch: 002/015 | Batch 100/293 | Cost: 0.5449\n",
      "Epoch: 002/015 | Batch 150/293 | Cost: 0.6161\n",
      "Epoch: 002/015 | Batch 200/293 | Cost: 0.5794\n",
      "Epoch: 002/015 | Batch 250/293 | Cost: 0.5190\n",
      "training accuracy: 75.81%\n",
      "valid accuracy: 75.02%\n",
      "Time elapsed: 0.57 min\n",
      "Epoch: 003/015 | Batch 000/293 | Cost: 0.5194\n",
      "Epoch: 003/015 | Batch 050/293 | Cost: 0.4679\n",
      "Epoch: 003/015 | Batch 100/293 | Cost: 0.5069\n",
      "Epoch: 003/015 | Batch 150/293 | Cost: 0.4728\n",
      "Epoch: 003/015 | Batch 200/293 | Cost: 0.4180\n",
      "Epoch: 003/015 | Batch 250/293 | Cost: 0.3722\n",
      "training accuracy: 77.14%\n",
      "valid accuracy: 76.48%\n",
      "Time elapsed: 0.85 min\n",
      "Epoch: 004/015 | Batch 000/293 | Cost: 0.4978\n",
      "Epoch: 004/015 | Batch 050/293 | Cost: 0.4959\n",
      "Epoch: 004/015 | Batch 100/293 | Cost: 0.4877\n",
      "Epoch: 004/015 | Batch 150/293 | Cost: 0.4808\n",
      "Epoch: 004/015 | Batch 200/293 | Cost: 0.4264\n",
      "Epoch: 004/015 | Batch 250/293 | Cost: 0.3528\n",
      "training accuracy: 82.63%\n",
      "valid accuracy: 80.89%\n",
      "Time elapsed: 1.14 min\n",
      "Epoch: 005/015 | Batch 000/293 | Cost: 0.3676\n",
      "Epoch: 005/015 | Batch 050/293 | Cost: 0.3325\n",
      "Epoch: 005/015 | Batch 100/293 | Cost: 0.4878\n",
      "Epoch: 005/015 | Batch 150/293 | Cost: 0.4481\n",
      "Epoch: 005/015 | Batch 200/293 | Cost: 0.4147\n",
      "Epoch: 005/015 | Batch 250/293 | Cost: 0.4270\n",
      "training accuracy: 84.73%\n",
      "valid accuracy: 82.78%\n",
      "Time elapsed: 1.42 min\n",
      "Epoch: 006/015 | Batch 000/293 | Cost: 0.4143\n",
      "Epoch: 006/015 | Batch 050/293 | Cost: 0.4586\n",
      "Epoch: 006/015 | Batch 100/293 | Cost: 0.3946\n",
      "Epoch: 006/015 | Batch 150/293 | Cost: 0.3729\n",
      "Epoch: 006/015 | Batch 200/293 | Cost: 0.3584\n",
      "Epoch: 006/015 | Batch 250/293 | Cost: 0.4089\n",
      "training accuracy: 86.25%\n",
      "valid accuracy: 84.17%\n",
      "Time elapsed: 1.71 min\n",
      "Epoch: 007/015 | Batch 000/293 | Cost: 0.3147\n",
      "Epoch: 007/015 | Batch 050/293 | Cost: 0.3494\n",
      "Epoch: 007/015 | Batch 100/293 | Cost: 0.2743\n",
      "Epoch: 007/015 | Batch 150/293 | Cost: 0.3913\n",
      "Epoch: 007/015 | Batch 200/293 | Cost: 0.2999\n",
      "Epoch: 007/015 | Batch 250/293 | Cost: 0.2530\n",
      "training accuracy: 86.61%\n",
      "valid accuracy: 84.16%\n",
      "Time elapsed: 1.99 min\n",
      "Epoch: 008/015 | Batch 000/293 | Cost: 0.3180\n",
      "Epoch: 008/015 | Batch 050/293 | Cost: 0.3589\n",
      "Epoch: 008/015 | Batch 100/293 | Cost: 0.3230\n",
      "Epoch: 008/015 | Batch 150/293 | Cost: 0.3192\n",
      "Epoch: 008/015 | Batch 200/293 | Cost: 0.3328\n",
      "Epoch: 008/015 | Batch 250/293 | Cost: 0.2283\n",
      "training accuracy: 87.09%\n",
      "valid accuracy: 84.59%\n",
      "Time elapsed: 2.29 min\n",
      "Epoch: 009/015 | Batch 000/293 | Cost: 0.3429\n",
      "Epoch: 009/015 | Batch 050/293 | Cost: 0.3042\n",
      "Epoch: 009/015 | Batch 100/293 | Cost: 0.2704\n",
      "Epoch: 009/015 | Batch 150/293 | Cost: 0.2430\n",
      "Epoch: 009/015 | Batch 200/293 | Cost: 0.4137\n",
      "Epoch: 009/015 | Batch 250/293 | Cost: 0.1736\n",
      "training accuracy: 74.11%\n",
      "valid accuracy: 72.36%\n",
      "Time elapsed: 2.59 min\n",
      "Epoch: 010/015 | Batch 000/293 | Cost: 0.5759\n",
      "Epoch: 010/015 | Batch 050/293 | Cost: 0.4807\n",
      "Epoch: 010/015 | Batch 100/293 | Cost: 0.2686\n",
      "Epoch: 010/015 | Batch 150/293 | Cost: 0.3420\n",
      "Epoch: 010/015 | Batch 200/293 | Cost: 0.2759\n",
      "Epoch: 010/015 | Batch 250/293 | Cost: 0.3928\n",
      "training accuracy: 89.58%\n",
      "valid accuracy: 86.27%\n",
      "Time elapsed: 2.88 min\n",
      "Epoch: 011/015 | Batch 000/293 | Cost: 0.2417\n",
      "Epoch: 011/015 | Batch 050/293 | Cost: 0.3175\n",
      "Epoch: 011/015 | Batch 100/293 | Cost: 0.2029\n",
      "Epoch: 011/015 | Batch 150/293 | Cost: 0.2389\n",
      "Epoch: 011/015 | Batch 200/293 | Cost: 0.3107\n",
      "Epoch: 011/015 | Batch 250/293 | Cost: 0.3486\n",
      "training accuracy: 90.21%\n",
      "valid accuracy: 86.52%\n",
      "Time elapsed: 3.17 min\n",
      "Epoch: 012/015 | Batch 000/293 | Cost: 0.2540\n",
      "Epoch: 012/015 | Batch 050/293 | Cost: 0.2851\n",
      "Epoch: 012/015 | Batch 100/293 | Cost: 0.1901\n",
      "Epoch: 012/015 | Batch 150/293 | Cost: 0.2286\n",
      "Epoch: 012/015 | Batch 200/293 | Cost: 0.3239\n",
      "Epoch: 012/015 | Batch 250/293 | Cost: 0.2856\n",
      "training accuracy: 90.72%\n",
      "valid accuracy: 86.78%\n",
      "Time elapsed: 3.47 min\n",
      "Epoch: 013/015 | Batch 000/293 | Cost: 0.1913\n",
      "Epoch: 013/015 | Batch 050/293 | Cost: 0.2547\n",
      "Epoch: 013/015 | Batch 100/293 | Cost: 0.3984\n",
      "Epoch: 013/015 | Batch 150/293 | Cost: 0.2294\n",
      "Epoch: 013/015 | Batch 200/293 | Cost: 0.2692\n",
      "Epoch: 013/015 | Batch 250/293 | Cost: 0.2132\n",
      "training accuracy: 91.51%\n",
      "valid accuracy: 87.13%\n",
      "Time elapsed: 3.76 min\n",
      "Epoch: 014/015 | Batch 000/293 | Cost: 0.1699\n",
      "Epoch: 014/015 | Batch 050/293 | Cost: 0.2611\n",
      "Epoch: 014/015 | Batch 100/293 | Cost: 0.2594\n",
      "Epoch: 014/015 | Batch 150/293 | Cost: 0.2062\n",
      "Epoch: 014/015 | Batch 200/293 | Cost: 0.2608\n",
      "Epoch: 014/015 | Batch 250/293 | Cost: 0.2881\n",
      "training accuracy: 91.43%\n",
      "valid accuracy: 86.93%\n",
      "Time elapsed: 4.05 min\n",
      "Epoch: 015/015 | Batch 000/293 | Cost: 0.2522\n",
      "Epoch: 015/015 | Batch 050/293 | Cost: 0.2753\n",
      "Epoch: 015/015 | Batch 100/293 | Cost: 0.2322\n",
      "Epoch: 015/015 | Batch 150/293 | Cost: 0.2361\n",
      "Epoch: 015/015 | Batch 200/293 | Cost: 0.3728\n",
      "Epoch: 015/015 | Batch 250/293 | Cost: 0.2895\n",
      "training accuracy: 89.71%\n",
      "valid accuracy: 85.54%\n",
      "Time elapsed: 4.34 min\n",
      "Total Training Time: 4.34 min\n",
      "Test accuracy: 86.88%\n"
     ]
    }
   ],
   "source": [
    "start_time = time.time()\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    model.train()\n",
    "    for batch_idx, batch_data in enumerate(train_loader):\n",
    "        \n",
    "        text, text_lengths = batch_data.review\n",
    "        \n",
    "        ### FORWARD AND BACK PROP\n",
    "        logits = model(text, text_lengths)\n",
    "        cost = F.binary_cross_entropy_with_logits(logits, batch_data.sentiment)\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        cost.backward()\n",
    "        \n",
    "        ### UPDATE MODEL PARAMETERS\n",
    "        optimizer.step()\n",
    "        \n",
    "        ### LOGGING\n",
    "        if not batch_idx % 50:\n",
    "            print (f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '\n",
    "                   f'Batch {batch_idx:03d}/{len(train_loader):03d} | '\n",
    "                   f'Cost: {cost:.4f}')\n",
    "\n",
    "    with torch.set_grad_enabled(False):\n",
    "        print(f'training accuracy: '\n",
    "              f'{compute_binary_accuracy(model, train_loader, DEVICE):.2f}%'\n",
    "              f'\\nvalid accuracy: '\n",
    "              f'{compute_binary_accuracy(model, valid_loader, DEVICE):.2f}%')\n",
    "        \n",
    "    print(f'Time elapsed: {(time.time() - start_time)/60:.2f} min')\n",
    "    \n",
    "print(f'Total Training Time: {(time.time() - start_time)/60:.2f} min')\n",
    "print(f'Test accuracy: {compute_binary_accuracy(model, test_loader, DEVICE):.2f}%')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "jt55pscgFdKZ"
   },
   "outputs": [],
   "source": [
    "import spacy\n",
    "nlp = spacy.load('en')\n",
    "\n",
    "def predict_sentiment(model, sentence):\n",
    "    # based on:\n",
    "    # https://github.com/bentrevett/pytorch-sentiment-analysis/blob/\n",
    "    # master/2%20-%20Upgraded%20Sentiment%20Analysis.ipynb\n",
    "    model.eval()\n",
    "    tokenized = [tok.text for tok in nlp.tokenizer(sentence)]\n",
    "    indexed = [TEXT.vocab.stoi[t] for t in tokenized]\n",
    "    length = [len(indexed)]\n",
    "    tensor = torch.LongTensor(indexed).to(DEVICE)\n",
    "    tensor = tensor.unsqueeze(1)\n",
    "    length_tensor = torch.LongTensor(length)\n",
    "    prediction = torch.sigmoid(model(tensor, length_tensor))\n",
    "    return prediction.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "O4__q0coFJyw",
    "outputId": "1a7f84ba-a977-455e-e248-3b7036d496d0"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Probability positive:\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.8258040845394135"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
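    "# LabelField indexes labels by frequency: '1' (positive) is more frequent in the\n",
    "# training set and gets index 0, so the sigmoid output is P(negative review)\n",
    "print('Probability positive:')\n",
    "1-predict_sentiment(model, \"This is such an awesome movie, I really love it!\")"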
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Probability negative:\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.8462136387825012"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# the sigmoid output corresponds to label index 1, i.e., the original label '0' (negative)\n",
    "print('Probability negative:')\n",
    "predict_sentiment(model, \"I really hate this movie. It is really bad and sucks!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "7lRusB3dF80X"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch       1.0.1.post2\n",
      "pandas      0.23.4\n",
      "spacy       2.0.16\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%watermark -iv"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "rnn_lstm_packed_imdb.ipynb",
   "provenance": [],
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
